#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Fit a [2, 5, 1] Kolmogorov-Arnold Network (pykan) to f(x, y) = exp(sin(pi*x) + y^2).

Author: Xiaoxu Zhang
Date: 2024-05-22
References: https://zhuanlan.zhihu.com/p/698511539
"""
from kan import *
import torch
import matplotlib.pyplot as plt
torch.manual_seed(42)


def target_fn(x):
    """Return f(x, y) = exp(sin(pi * x) + y**2) evaluated row-wise on ``x``.

    ``x`` is indexed as ``x[:, [0]]`` and ``x[:, [1]]``, so it must be a 2-D
    tensor whose first two columns hold the x and y inputs. The list-style
    indexing keeps a trailing singleton dimension, so the result has shape
    (n_rows, 1).
    """
    return torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)


# Sampling range for both input variables. With torch's default float32
# dtype, exp() overflows to inf once its argument exceeds ~88.7, i.e. once
# |y| > ~9.4. The previous [-10, 10] range therefore produced inf labels
# (and NaN losses during training). On [-1, 1] every label is at most e**2.
INPUT_RANGE = [-1, 1]


if __name__ == '__main__':

    model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)

    # Build the dataset for f(x, y) = exp(sin(pi*x) + y^2).
    dataset = create_dataset(target_fn, n_var=2, ranges=INPUT_RANGE)

    # NOTE(review): newer pykan releases expose the training loop as
    # `KAN.fit` (there `train` is torch's nn.Module.train(mode)); older ones
    # use `KAN.train`. Prefer `fit` when present so the script works on both.
    fit = getattr(model, "fit", None) or model.train
    fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=20.)

    model.plot()
    plt.show()
